import numpy as np
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim
import torch
from network.predict_net import Predict_mse,Predict_combine_mse
import torch.nn as nn
import torch as th
import torch
from torch.autograd import Variable
from network.norm import DynamicNorm
from sklearn.cluster import AgglomerativeClustering
import os
from scipy.stats import wasserstein_distance
from torch.distributions import kl_divergence
import torch.distributions as D


def setup_seed(seed, num_threads=8):
    """Seed PyTorch's RNGs and configure cuDNN / threading for reproducible runs.

    Args:
        seed: Integer seed passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
        num_threads: Thread count passed to ``torch.set_num_threads``.
            Defaults to 8, the value that used to be hard-coded here.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select a different kernel from run to run, which
    # undermines the deterministic flag above, so it is switched off here.
    torch.backends.cudnn.benchmark = False
    torch.set_num_threads(num_threads)
class pre_AC(nn.Module):
    """Bank of per-agent prediction heads (one ``Predict_mse`` module per agent).

    Every head is built as
    ``Predict_mse(input_shape + args.rnn_hidden_dim, 128, out_dim, False)``.
    """

    # Original author's note: because all the agents share the same network,
    # input_shape = obs_shape + n_actions + n_agents.
    def __init__(self, input_shape, out_dim, args):
        super().__init__()
        self.args = args
        self.n_agents = args.n_agents
        head_input = input_shape + args.rnn_hidden_dim
        heads = [Predict_mse(head_input, 128, out_dim, False) for _ in range(args.n_agents)]
        self.pre_mlp = nn.ModuleList(heads)

    def forward(self, h_obs, obs, train_label=False):
        """Run each agent's head on that agent's slice of ``h_obs``.

        Args:
            h_obs: Features fed to the heads. With ``train_label`` false it is
                indexed per agent along dim 0; with ``train_label`` true it is
                regrouped to ``(-1, n_agents, obs.shape[1], feature)`` first.
            obs: Only ``obs.shape[1]`` is read, and only when ``train_label``
                is true.
            train_label: Selects the training path (gradients tracked) over
                the execution path (wrapped in ``torch.no_grad()``).

        Returns:
            The concatenated head outputs; on the training path they are
            flattened to two dimensions.
        """
        agent_ids = range(self.n_agents)
        if not train_label:
            # Execution path: one row per agent, no autograd bookkeeping.
            with torch.no_grad():
                rows = [self.pre_mlp[k](h_obs[k].unsqueeze(0)) for k in agent_ids]
                return torch.cat(rows)

        feature_dim = h_obs.shape[-1]
        grouped = h_obs.reshape(-1, self.n_agents, obs.shape[1], feature_dim)
        per_agent = [self.pre_mlp[k](grouped[:, k].unsqueeze(1)) for k in agent_ids]
        stacked = torch.cat(per_agent, dim=1)
        return stacked.reshape(-1, stacked.shape[-1])
class RNN_coma(nn.Module):
    """GRU Q-network without communication; every input row is a one-step sequence."""

    def __init__(self, input_shape, args):
        super().__init__()
        self.args = args
        self.input_shape = input_shape
        self.n_agents = args.n_agents
        width = args.rnn_hidden_dim
        self.fc1 = nn.Linear(input_shape, width)
        self.rnn = nn.GRU(
            input_size=width,
            hidden_size=width,
            num_layers=1,
            batch_first=True,
        )
        self.fc2 = nn.Linear(width, args.n_actions)

    def init_hidden(self):
        """Return a zeroed ``(1, rnn_hidden_dim)`` hidden state on the model's device."""
        blank = self.fc1.weight.new(1, self.args.rnn_hidden_dim)
        return blank.zero_()

    def forward(self, obs, hidden_state):
        """Compute Q-values and GRU outputs for ``obs``.

        Args:
            obs: Input whose last dimension is the feature dimension; dim 0
                is used as the GRU batch and each row is one time step.
            hidden_state: Initial GRU state; a 2-D tensor gets a leading
                layer axis added.

        Returns:
            ``(q, gru_out)``. When ``obs.shape[0] == n_agents`` the pair is
            ``(q.squeeze(), gru_out.permute(1, 0, 2))`` instead.
        """
        if len(hidden_state.shape) == 2:
            # Add the layer axis the GRU expects, e.g. (1, 3, 64).
            hidden_state = hidden_state.unsqueeze(0)
        rows = obs.shape[0]
        executing = rows == self.n_agents

        embedded = f.relu(self.fc1(obs.view(-1, obs.shape[-1])))
        gru_out, _ = self.rnn(embedded.reshape(rows, 1, -1), hidden_state)
        flat_out = gru_out.reshape(-1, gru_out.shape[-1])
        q = self.fc2(flat_out).reshape(rows, 1, -1)

        if not executing:
            return q, gru_out
        return q.squeeze(), gru_out.permute(1, 0, 2)

    def update(self, inputs, vae_mu, vae_sigma, mask):
        """No-op placeholder; returns ``None``."""

class RNN_quick(nn.Module):
    """GRU Q-network without communication that processes whole sequences."""

    def __init__(self, input_shape, args):
        super(RNN_quick, self).__init__()
        self.args = args
        # Seed before creating any layer so that fc1's initial weights are
        # governed by the seed too (it used to be built before seeding, unlike
        # the sibling networks in this file).
        setup_seed(args.seed)
        self.input_shape = input_shape
        self.n_agents = args.n_agents
        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.rnn = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

    def init_hidden(self):
        """Return a zeroed ``(1, rnn_hidden_dim)`` hidden state on the model's device."""
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def forward(self, obs, hidden_state):
        """Compute Q-values and GRU outputs for a batch of sequences.

        Args:
            obs: Tensor with at least three dimensions; dim 0 is used as the
                GRU batch, dim 1 as the sequence axis and the last dimension
                as features.
            hidden_state: Initial GRU state; a 2-D tensor gets a leading
                layer axis added.

        Returns:
            ``(q, gru_out)`` with ``q`` shaped ``(obs.shape[0], obs.shape[1], -1)``.
            When ``obs.shape[0] == n_agents`` the pair is
            ``(q.squeeze(), gru_out.permute(1, 0, 2))`` instead.
        """
        if len(hidden_state.shape) == 2:
            # Add the layer axis the GRU expects, e.g. (1, 3, 64).
            hidden_state = hidden_state.unsqueeze(0)
        executing = obs.shape[0] == self.n_agents

        flat_obs = obs.view(-1, obs.shape[-1])
        x = f.relu(self.fc1(flat_obs))
        x = x.reshape(obs.shape[0], obs.shape[1], -1)
        gru_out, _ = self.rnn(x, hidden_state)
        q = self.fc2(gru_out.reshape(-1, gru_out.shape[-1]))
        q = q.reshape(obs.shape[0], obs.shape[1], -1)
        if executing:
            return q.squeeze(), gru_out.permute(1, 0, 2)
        return q, gru_out

    def update(self, inputs, vae_mu, vae_sigma, mask):
        """No-op placeholder; returns ``None``."""

class RNN_future_s(nn.Module):
    """Recurrent Q-network (no communication) with a latent hyper-head and a state-based KL term.

    The GRU output is encoded into a diagonal Gaussian latent. A sample of
    that latent generates per-row weights and biases for an extra Q head,
    whose output is added to the plain ``fc2`` head. When ``s`` and ``u`` are
    passed to ``forward``, a reverse-time encoding of the state/action
    sequence is used to compute a KL term and a scalar prediction.
    """

    def __init__(self, input_shape, args):
        super(RNN_future_s, self).__init__()
        self.args = args
        setup_seed(args.seed)
        self.input_shape = input_shape
        self.n_agents = args.n_agents
        hidden = args.rnn_hidden_dim
        latent = args.indi_latent_dim
        n_actions = args.n_actions

        # NOTE: layers are created in a fixed order; with the seed set above,
        # that order determines each layer's initial weights.
        self.fc1 = nn.Linear(input_shape, hidden)
        self.rnn = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.fc2 = nn.Linear(hidden, n_actions)

        # Latent encoder: mean and log-variance heads.
        self.indi_net = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))
        self.indi_mu = nn.Linear(hidden, latent)
        self.indi_lnsigma2 = nn.Linear(hidden, latent)

        self.params_net = nn.Sequential(nn.Linear(latent * 2, hidden), nn.ReLU(inplace=True))
        self.params_dis_net = nn.Sequential(nn.Linear(latent, hidden), nn.ReLU(inplace=True))

        # Hyper-network producing the weights / bias of the extra Q head.
        self.fc2_w_nn = nn.Linear(hidden, hidden * n_actions)
        self.fc2_b_nn = nn.Linear(hidden, n_actions)

        self.infer_net_s = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.ReLU(inplace=True))
        self.infer_mu_s = nn.Linear(hidden, latent)
        self.infer_lnsigma_s = nn.Linear(hidden, latent)
        self.pre_rew_s = nn.Linear(latent, 1)

        self.infer_net_obs = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.ReLU(inplace=True))
        self.infer_mu_obs = nn.Linear(hidden, latent)
        self.infer_lnsigma_obs = nn.Linear(hidden, latent)
        self.pre_rew_obs = nn.Linear(latent, 1)

        self.fc_state = nn.Sequential(
            nn.Linear(args.state_shape + args.n_agents * n_actions, hidden), nn.ReLU(inplace=True))
        self.fc_obs = nn.Sequential(nn.Linear(input_shape + n_actions, hidden), nn.ReLU(inplace=True))

        self.rnn_state = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.rnn_obs = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.fc_obs_state = nn.Sequential(
            nn.Linear(args.state_shape + n_actions, hidden), nn.ReLU(inplace=True))
        self.obs_wb = nn.Linear(input_shape, hidden)
        self.obs_w_nn = nn.Linear(hidden, hidden ** 2)
        self.obs_b_nn = nn.Linear(hidden, hidden)
        self.rnn_pastobs = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.norm = DynamicNorm(input_shape, only_for_last_dim=True, exclude_one_hot=True, exclude_nan=True)

    def init_hidden(self):
        """Return a zeroed ``(1, rnn_hidden_dim)`` hidden state on the model's device."""
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def flipUp(self, s_in, mask):
        """Reverse the masked-in steps along dim 1, writing them into the masked-in slots.

        Args:
            s_in: 3-D tensor; dim 1 is the axis that gets reversed.
            mask: 3-D tensor with last dimension 1; non-zero marks a valid step.

        Returns:
            Tensor shaped like ``s_in``; positions where ``mask`` is zero are zero.
        """
        full_mask = mask.repeat(1, 1, s_in.shape[-1]).bool()
        reversed_seq = s_in.flip(dims=[1])
        flipped = torch.zeros_like(s_in)
        flipped[full_mask] = reversed_seq[full_mask.flip(dims=[1])]
        return flipped

    def past_obs_h(self, obs, b, a):
        """Generate per-row square weights ``(b*a, H, H)`` and biases ``(b*a, 1, H)`` from ``obs``."""
        hidden = self.args.rnn_hidden_dim
        x = self.obs_wb(obs)
        weights = self.obs_w_nn(x).view(b * a, hidden, hidden)
        biases = self.obs_b_nn(x).view(b * a, 1, hidden)
        return weights, biases

    def future_obs_h(self, obs):
        """Encode ``obs`` with ``rnn_obs`` run over the reversed valid steps; returns 2-D rows."""
        hidden_state = torch.zeros((1, obs.shape[0], self.args.rnn_hidden_dim)).to(obs.device)
        # A step counts as valid when any of its features is non-zero.
        mask = torch.any(obs.bool(), dim=-1).float().unsqueeze(-1)
        obs_h = self.flipUp(obs, mask)
        obs_h = self.fc_obs(obs_h)
        obs_h, _ = self.rnn_obs(obs_h, hidden_state)
        obs_h = self.flipUp(obs_h, mask)
        return obs_h.reshape(-1, obs_h.shape[-1])

    def future_s_h(self, s, obs):
        """Encode ``s`` with ``rnn_state`` run over the reversed valid steps.

        The result is repeated ``n_agents`` times along a new dim 1 and then
        flattened to 2-D rows. ``obs`` is accepted but not used.
        """
        hidden_state = torch.zeros((1, s.shape[0], self.args.rnn_hidden_dim)).to(s.device)
        # A step counts as valid when any of its features is non-zero.
        mask = torch.any(s.bool(), dim=-1).float().unsqueeze(-1)
        s_h = self.flipUp(s, mask)
        s_h = self.fc_state(s_h)
        s_h, _ = self.rnn_state(s_h, hidden_state)
        s_h = self.flipUp(s_h, mask)
        return s_h.unsqueeze(1).repeat(1, self.n_agents, 1, 1).reshape(-1, s_h.shape[-1])

    def get_s_MI(self, gru_out_c, s_h, latent_embed, b, a):
        """Return ``(r_MI, s_rew)`` for the state branch.

        ``r_MI`` is KL(``latent_embed`` || inferred Gaussian) summed over the
        latent dimension and reshaped to ``(b, a)``; the inferred parameters
        are detached for that term. ``s_rew`` is ``pre_rew_s`` applied to a
        sample of the inferred Gaussian, averaged over the agent axis.
        """
        latent_infer = self.infer_net_s(th.cat([gru_out_c.detach(), s_h], dim=1))
        infer_mu = self.infer_mu_s(latent_infer)
        infer_lnsigma = self.infer_lnsigma_s(latent_infer)
        inferred = D.Normal(infer_mu.detach(), th.exp(infer_lnsigma.detach() / 2))
        sample = self.reparametrize(infer_mu, infer_lnsigma)
        s_rew = self.pre_rew_s(sample).reshape(-1, self.n_agents, a, 1).mean(dim=1)
        r_MI = kl_divergence(latent_embed, inferred).sum(dim=-1).reshape(b, a)
        return r_MI, s_rew

    def get_obs_MI(self, gru_out_c, obs_h, latent_embed, b, a):
        """Return ``(r_MI, obs_rew)`` for the observation branch (same scheme as ``get_s_MI``)."""
        latent_infer = self.infer_net_obs(th.cat([gru_out_c.detach(), obs_h], dim=1))
        infer_mu = self.infer_mu_obs(latent_infer)
        infer_lnsigma = self.infer_lnsigma_obs(latent_infer)
        inferred = D.Normal(infer_mu.detach(), th.exp(infer_lnsigma.detach() / 2))
        sample = self.reparametrize(infer_mu, infer_lnsigma)
        obs_rew = self.pre_rew_obs(sample).reshape(-1, self.n_agents, a, 1).mean(dim=1)
        r_MI = kl_divergence(latent_embed, inferred).sum(dim=-1).reshape(b, a)
        return r_MI, obs_rew

    def forward(self, obs, hidden_state, s=None, u=None):
        """Compute Q-values and, when ``s`` is given, the state-based KL term.

        Args:
            obs: 3-D tensor; dim 0 is the GRU batch, dim 1 the sequence axis.
            hidden_state: Initial GRU state; a 2-D tensor gets a leading
                layer axis added.
            s: Optional 3-D state tensor. ``None`` skips the KL branch.
            u: Actions; required when ``s`` is given. Its first two dims must
                match those of ``s`` (the rest is flattened and concatenated).

        Returns:
            ``(q, gru_out, r_MI, s_rew, obs_rew)``; ``r_MI`` and ``s_rew`` are
            ``None`` when ``s`` is ``None`` and ``obs_rew`` is always ``None``.
            When ``obs.shape[0] == n_agents`` only
            ``(q.squeeze(), gru_out.permute(1, 0, 2))`` is returned.
        """
        b, a, _ = obs.size()
        if len(hidden_state.shape) == 2:
            # Add the layer axis the GRU expects, e.g. (1, 3, 64).
            hidden_state = hidden_state.unsqueeze(0)
        # NOTE(review): execution mode is inferred from the leading dimension
        # alone; confirm no training batch can have exactly n_agents rows.
        executing = obs.shape[0] == self.n_agents

        r_MI = None
        s_rew = None
        obs_rew = None

        x = f.relu(self.fc1(obs.view(-1, obs.shape[-1])))
        x = x.reshape(obs.shape[0], obs.shape[1], -1)
        gru_out, _ = self.rnn(x, hidden_state)
        gru_out_c = gru_out.reshape(-1, gru_out.shape[-1])

        indi = self.indi_net(gru_out_c)
        indi_mu = self.indi_mu(indi)  # (b*a, indi_latent_dim)
        indi_lnsigma2 = self.indi_lnsigma2(indi)
        indi_latent_dis = self.reparametrize(indi_mu, indi_lnsigma2)  # (b*a, indi_latent_dim)
        latent_embed = D.Normal(indi_mu, th.exp(indi_lnsigma2 / 2))

        # Identity check: comparing a tensor to None with != is unreliable.
        if s is not None:
            s_f = torch.cat((s, u.reshape(u.shape[0], u.shape[1], -1)), dim=-1)
            s_h = self.future_s_h(s_f, obs)
            r_MI_s, s_rew = self.get_s_MI(gru_out_c, s_h, latent_embed, b, a)
            r_MI = (self.args.beta1 * r_MI_s).clamp(max=20)

        latent_para = self.params_dis_net(indi_latent_dis)  # (b*a, H)
        fc2_w = self.fc2_w_nn(latent_para).view(b * a, self.args.rnn_hidden_dim, self.args.n_actions)
        fc2_b = self.fc2_b_nn(latent_para).view(b * a, 1, self.args.n_actions)
        h = gru_out.reshape(b * a, 1, self.args.rnn_hidden_dim)
        q_2 = torch.matmul(h, fc2_w) + fc2_b  # latent-generated head
        q_1 = self.fc2(h)  # shared head
        q = (q_1 + q_2).reshape(obs.shape[0], obs.shape[1], -1)
        if executing:
            return q.squeeze(), gru_out.permute(1, 0, 2)
        return q, gru_out, r_MI, s_rew, obs_rew

    def reparametrize(self, mu, logvar):
        """Sample ``mu + 0.001 * eps * exp(logvar / 2)`` with ``eps`` ~ N(0, I).

        ``logvar`` is treated as the log-variance. The result has the same
        shape as ``mu``.
        """
        origin_size = mu.size()
        flat_mu = mu.view(-1, mu.size(-1))
        flat_logvar = logvar.view(-1, logvar.size(-1))
        # Standard-normal noise, drawn with the default generator and then
        # moved to mu's device (the Variable wrapper is deprecated and a no-op).
        eps = torch.randn(flat_mu.size(0), flat_mu.size(1)).to(flat_mu.device)
        z = flat_mu + 0.001 * eps * torch.exp(flat_logvar / 2)
        return z.view(origin_size)

    def update(self, inputs, vae_mu, vae_sigma, mask):
        """No-op placeholder; returns ``None``."""
class RNN_future_obs(nn.Module):
    """Recurrent Q-network (no communication) with a latent hyper-head and an observation-based KL term.

    The GRU output is encoded into a diagonal Gaussian latent. A sample of
    that latent generates per-row weights and biases for an extra Q head,
    whose output is added to the plain ``fc2`` head. When ``s`` and ``u`` are
    passed to ``forward``, a reverse-time encoding of the observation/action
    sequence is used to compute a KL term and a scalar prediction.
    """

    def __init__(self, input_shape, args):
        super(RNN_future_obs, self).__init__()
        self.args = args
        setup_seed(args.seed)
        self.input_shape = input_shape
        self.n_agents = args.n_agents
        hidden = args.rnn_hidden_dim
        latent = args.indi_latent_dim
        n_actions = args.n_actions

        # NOTE: layers are created in a fixed order; with the seed set above,
        # that order determines each layer's initial weights.
        self.fc1 = nn.Linear(input_shape, hidden)
        self.rnn = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.fc2 = nn.Linear(hidden, n_actions)

        # Latent encoder: mean and log-variance heads.
        self.indi_net = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))
        self.indi_mu = nn.Linear(hidden, latent)
        self.indi_lnsigma2 = nn.Linear(hidden, latent)

        self.params_net = nn.Sequential(nn.Linear(latent * 2, hidden), nn.ReLU(inplace=True))
        self.params_dis_net = nn.Sequential(nn.Linear(latent, hidden), nn.ReLU(inplace=True))

        # Hyper-network producing the weights / bias of the extra Q head.
        self.fc2_w_nn = nn.Linear(hidden, hidden * n_actions)
        self.fc2_b_nn = nn.Linear(hidden, n_actions)

        self.infer_net_s = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.ReLU(inplace=True))
        self.infer_mu_s = nn.Linear(hidden, latent)
        self.infer_lnsigma_s = nn.Linear(hidden, latent)
        self.pre_rew_s = nn.Linear(latent, 1)

        self.infer_net_obs = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.ReLU(inplace=True))
        self.infer_mu_obs = nn.Linear(hidden, latent)
        self.infer_lnsigma_obs = nn.Linear(hidden, latent)
        self.pre_rew_obs = nn.Linear(latent, 1)

        self.fc_state = nn.Sequential(
            nn.Linear(args.state_shape + args.n_agents * n_actions, hidden), nn.ReLU(inplace=True))
        self.fc_obs = nn.Sequential(nn.Linear(input_shape + n_actions, hidden), nn.ReLU(inplace=True))

        self.rnn_state = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.rnn_obs = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.fc_obs_state = nn.Sequential(
            nn.Linear(args.state_shape + n_actions, hidden), nn.ReLU(inplace=True))
        self.obs_wb = nn.Linear(input_shape, hidden)
        self.obs_w_nn = nn.Linear(hidden, hidden ** 2)
        self.obs_b_nn = nn.Linear(hidden, hidden)
        self.rnn_pastobs = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.norm = DynamicNorm(input_shape, only_for_last_dim=True, exclude_one_hot=True, exclude_nan=True)

    def init_hidden(self):
        """Return a zeroed ``(1, rnn_hidden_dim)`` hidden state on the model's device."""
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def flipUp(self, s_in, mask):
        """Reverse the masked-in steps along dim 1, writing them into the masked-in slots.

        Args:
            s_in: 3-D tensor; dim 1 is the axis that gets reversed.
            mask: 3-D tensor with last dimension 1; non-zero marks a valid step.

        Returns:
            Tensor shaped like ``s_in``; positions where ``mask`` is zero are zero.
        """
        full_mask = mask.repeat(1, 1, s_in.shape[-1]).bool()
        reversed_seq = s_in.flip(dims=[1])
        flipped = torch.zeros_like(s_in)
        flipped[full_mask] = reversed_seq[full_mask.flip(dims=[1])]
        return flipped

    def past_obs_h(self, obs, b, a):
        """Generate per-row square weights ``(b*a, H, H)`` and biases ``(b*a, 1, H)`` from ``obs``."""
        hidden = self.args.rnn_hidden_dim
        x = self.obs_wb(obs)
        weights = self.obs_w_nn(x).view(b * a, hidden, hidden)
        biases = self.obs_b_nn(x).view(b * a, 1, hidden)
        return weights, biases

    def future_obs_h(self, obs):
        """Encode ``obs`` with ``rnn_obs`` run over the reversed valid steps; returns 2-D rows."""
        hidden_state = torch.zeros((1, obs.shape[0], self.args.rnn_hidden_dim)).to(obs.device)
        # A step counts as valid when any of its features is non-zero.
        mask = torch.any(obs.bool(), dim=-1).float().unsqueeze(-1)
        obs_h = self.flipUp(obs, mask)
        obs_h = self.fc_obs(obs_h)
        obs_h, _ = self.rnn_obs(obs_h, hidden_state)
        obs_h = self.flipUp(obs_h, mask)
        return obs_h.reshape(-1, obs_h.shape[-1])

    def future_s_h(self, s, obs):
        """Encode ``s`` with ``rnn_state`` run over the reversed valid steps.

        The result is repeated ``n_agents`` times along a new dim 1 and then
        flattened to 2-D rows. ``obs`` is accepted but not used.
        """
        hidden_state = torch.zeros((1, s.shape[0], self.args.rnn_hidden_dim)).to(s.device)
        # A step counts as valid when any of its features is non-zero.
        mask = torch.any(s.bool(), dim=-1).float().unsqueeze(-1)
        s_h = self.flipUp(s, mask)
        s_h = self.fc_state(s_h)
        s_h, _ = self.rnn_state(s_h, hidden_state)
        s_h = self.flipUp(s_h, mask)
        return s_h.unsqueeze(1).repeat(1, self.n_agents, 1, 1).reshape(-1, s_h.shape[-1])

    def get_s_MI(self, gru_out_c, s_h, latent_embed, b, a):
        """Return ``(r_MI, s_rew)`` for the state branch.

        ``r_MI`` is KL(``latent_embed`` || inferred Gaussian) summed over the
        latent dimension and reshaped to ``(b, a)``; the inferred parameters
        are detached for that term. ``s_rew`` is ``pre_rew_s`` applied to a
        sample of the inferred Gaussian, averaged over the agent axis.
        """
        latent_infer = self.infer_net_s(th.cat([gru_out_c.detach(), s_h], dim=1))
        infer_mu = self.infer_mu_s(latent_infer)
        infer_lnsigma = self.infer_lnsigma_s(latent_infer)
        inferred = D.Normal(infer_mu.detach(), th.exp(infer_lnsigma.detach() / 2))
        sample = self.reparametrize(infer_mu, infer_lnsigma)
        s_rew = self.pre_rew_s(sample).reshape(-1, self.n_agents, a, 1).mean(dim=1)
        r_MI = kl_divergence(latent_embed, inferred).sum(dim=-1).reshape(b, a)
        return r_MI, s_rew

    def get_obs_MI(self, gru_out_c, obs_h, latent_embed, b, a):
        """Return ``(r_MI, obs_rew)`` for the observation branch (same scheme as ``get_s_MI``)."""
        latent_infer = self.infer_net_obs(th.cat([gru_out_c.detach(), obs_h], dim=1))
        infer_mu = self.infer_mu_obs(latent_infer)
        infer_lnsigma = self.infer_lnsigma_obs(latent_infer)
        inferred = D.Normal(infer_mu.detach(), th.exp(infer_lnsigma.detach() / 2))
        sample = self.reparametrize(infer_mu, infer_lnsigma)
        obs_rew = self.pre_rew_obs(sample).reshape(-1, self.n_agents, a, 1).mean(dim=1)
        r_MI = kl_divergence(latent_embed, inferred).sum(dim=-1).reshape(b, a)
        return r_MI, obs_rew

    def forward(self, obs, hidden_state, s=None, u=None):
        """Compute Q-values and, when ``s`` is given, the observation-based KL term.

        Args:
            obs: 3-D tensor; dim 0 is the GRU batch, dim 1 the sequence axis.
            hidden_state: Initial GRU state; a 2-D tensor gets a leading
                layer axis added.
            s: Only tested against ``None`` to enable the KL branch; its
                contents are not read here.
            u: 4-D action tensor; required when ``s`` is given. Dims 1 and 2
                are swapped and the first two dims merged before it is
                concatenated to ``obs`` along the last dimension.

        Returns:
            ``(q, gru_out, r_MI, s_rew, obs_rew)``; ``r_MI`` and ``obs_rew``
            are ``None`` when ``s`` is ``None`` and ``s_rew`` is always
            ``None``. When ``obs.shape[0] == n_agents`` only
            ``(q.squeeze(), gru_out.permute(1, 0, 2))`` is returned.
        """
        b, a, _ = obs.size()
        if len(hidden_state.shape) == 2:
            # Add the layer axis the GRU expects, e.g. (1, 3, 64).
            hidden_state = hidden_state.unsqueeze(0)
        # NOTE(review): execution mode is inferred from the leading dimension
        # alone; confirm no training batch can have exactly n_agents rows.
        executing = obs.shape[0] == self.n_agents

        r_MI = None
        s_rew = None
        obs_rew = None

        x = f.relu(self.fc1(obs.view(-1, obs.shape[-1])))
        x = x.reshape(obs.shape[0], obs.shape[1], -1)
        gru_out, _ = self.rnn(x, hidden_state)
        gru_out_c = gru_out.reshape(-1, gru_out.shape[-1])

        indi = self.indi_net(gru_out_c)
        indi_mu = self.indi_mu(indi)  # (b*a, indi_latent_dim)
        indi_lnsigma2 = self.indi_lnsigma2(indi)  # (b*a, indi_latent_dim)
        indi_latent_dis = self.reparametrize(indi_mu, indi_lnsigma2)  # (b*a, indi_latent_dim)
        latent_embed = D.Normal(indi_mu, th.exp(indi_lnsigma2 / 2))

        # Identity check: comparing a tensor to None with != is unreliable.
        if s is not None:
            actions = u.permute(0, 2, 1, 3).reshape(-1, u.shape[1], u.shape[-1])
            obs_f = torch.cat((obs, actions), dim=-1)
            obs_h = self.future_obs_h(obs_f)
            r_MI_obs, obs_rew = self.get_obs_MI(gru_out_c, obs_h, latent_embed, b, a)
            r_MI = (self.args.beta2 * r_MI_obs).clamp(max=20)

        latent_para = self.params_dis_net(indi_latent_dis)  # (b*a, H)
        fc2_w = self.fc2_w_nn(latent_para).view(b * a, self.args.rnn_hidden_dim, self.args.n_actions)
        fc2_b = self.fc2_b_nn(latent_para).view(b * a, 1, self.args.n_actions)
        h = gru_out.reshape(b * a, 1, self.args.rnn_hidden_dim)
        q_2 = torch.matmul(h, fc2_w) + fc2_b  # latent-generated head
        q_1 = self.fc2(h)  # shared head
        q = (q_1 + q_2).reshape(obs.shape[0], obs.shape[1], -1)
        if executing:
            return q.squeeze(), gru_out.permute(1, 0, 2)
        return q, gru_out, r_MI, s_rew, obs_rew

    def reparametrize(self, mu, logvar):
        """Sample ``mu + 0.001 * eps * exp(logvar / 2)`` with ``eps`` ~ N(0, I).

        ``logvar`` is treated as the log-variance. The result has the same
        shape as ``mu``.
        """
        origin_size = mu.size()
        flat_mu = mu.view(-1, mu.size(-1))
        flat_logvar = logvar.view(-1, logvar.size(-1))
        # Standard-normal noise, drawn with the default generator and then
        # moved to mu's device (the Variable wrapper is deprecated and a no-op).
        eps = torch.randn(flat_mu.size(0), flat_mu.size(1)).to(flat_mu.device)
        z = flat_mu + 0.001 * eps * torch.exp(flat_logvar / 2)
        return z.view(origin_size)

    def update(self, inputs, vae_mu, vae_sigma, mask):
        """No-op placeholder; returns ``None``."""
class old(nn.Module):
    """Earlier variant of the state+observation network (no communication).

    Two Gaussian latents are encoded from the GRU output. Each is passed
    through ``g_params_net`` and the results are concatenated with the GRU
    output as input to ``fc2``. When ``s`` and ``u`` are passed to
    ``forward``, one KL term is computed per latent (state branch and
    observation branch).

    NOTE(review): ``g_params_net`` takes ``rnn_hidden_dim`` inputs but is fed
    latents of width ``indi_latent_dim``, so ``forward`` only works when the
    two sizes are equal — confirm whether that is intended.
    """

    def __init__(self, input_shape, args):
        # This used to name RNN_future_s_obs, which raises TypeError because
        # an `old` instance is not an instance of that class.
        super().__init__()
        self.args = args
        setup_seed(args.seed)
        self.input_shape = input_shape
        self.n_agents = args.n_agents
        hidden = args.rnn_hidden_dim
        latent = args.indi_latent_dim
        n_actions = args.n_actions

        # NOTE: layers are created in a fixed order; with the seed set above,
        # that order determines each layer's initial weights.
        self.fc1 = nn.Linear(input_shape, hidden)
        self.rnn = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        # Consumes [gru_out, g_latent_para, id_latent_para] concatenated.
        self.fc2 = nn.Linear(hidden * 3, n_actions)

        # First latent encoder (paired with the state branch in forward).
        self.indi_net = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))
        self.indi_mu = nn.Linear(hidden, latent)
        self.indi_lnsigma2 = nn.Linear(hidden, latent)

        # Second latent encoder (paired with the observation branch in forward).
        self.id_indi_net = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))
        self.id_indi_mu = nn.Linear(hidden, latent)
        self.id_indi_lnsigma2 = nn.Linear(hidden, latent)

        self.params_net = nn.Sequential(nn.Linear(latent * 2, hidden), nn.ReLU(inplace=True))
        self.params_dis_net = nn.Sequential(nn.Linear(latent * 2, hidden), nn.ReLU(inplace=True))

        self.fc2_w_nn = nn.Linear(hidden, hidden * n_actions)
        self.fc2_b_nn = nn.Linear(hidden, n_actions)

        self.infer_net_s = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.ReLU(inplace=True))
        self.infer_mu_s = nn.Linear(hidden, latent)
        self.infer_lnsigma_s = nn.Linear(hidden, latent)
        self.pre_rew_s = nn.Linear(latent, 1)

        self.infer_net_obs = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.ReLU(inplace=True))
        self.infer_mu_obs = nn.Linear(hidden, latent)
        self.infer_lnsigma_obs = nn.Linear(hidden, latent)
        self.pre_rew_obs = nn.Linear(latent, 1)

        self.g_params_net = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))

        self.fc_state = nn.Sequential(
            nn.Linear(args.state_shape + args.n_agents * n_actions, hidden), nn.ReLU(inplace=True))
        self.fc_obs = nn.Sequential(nn.Linear(input_shape + n_actions, hidden), nn.ReLU(inplace=True))

        self.rnn_state = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.rnn_obs = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.fc_obs_state = nn.Sequential(
            nn.Linear(args.state_shape + n_actions, hidden), nn.ReLU(inplace=True))
        self.obs_wb = nn.Linear(input_shape, hidden)
        self.obs_w_nn = nn.Linear(hidden, hidden ** 2)
        self.obs_b_nn = nn.Linear(hidden, hidden)
        self.rnn_pastobs = nn.GRU(input_size=hidden, num_layers=1, hidden_size=hidden, batch_first=True)
        self.norm = DynamicNorm(input_shape, only_for_last_dim=True, exclude_one_hot=True, exclude_nan=True)

    def init_hidden(self):
        """Return a zeroed ``(1, rnn_hidden_dim)`` hidden state on the model's device."""
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def flipUp(self, s_in, mask):
        """Reverse the masked-in steps along dim 1, writing them into the masked-in slots.

        Args:
            s_in: 3-D tensor; dim 1 is the axis that gets reversed.
            mask: 3-D tensor with last dimension 1; non-zero marks a valid step.

        Returns:
            Tensor shaped like ``s_in``; positions where ``mask`` is zero are zero.
        """
        full_mask = mask.repeat(1, 1, s_in.shape[-1]).bool()
        reversed_seq = s_in.flip(dims=[1])
        flipped = torch.zeros_like(s_in)
        flipped[full_mask] = reversed_seq[full_mask.flip(dims=[1])]
        return flipped

    def past_obs_h(self, obs, b, a):
        """Generate per-row square weights ``(b*a, H, H)`` and biases ``(b*a, 1, H)`` from ``obs``."""
        hidden = self.args.rnn_hidden_dim
        x = self.obs_wb(obs)
        weights = self.obs_w_nn(x).view(b * a, hidden, hidden)
        biases = self.obs_b_nn(x).view(b * a, 1, hidden)
        return weights, biases

    def future_obs_h(self, obs):
        """Encode ``obs`` with ``rnn_obs`` run over the reversed valid steps; returns 2-D rows."""
        hidden_state = torch.zeros((1, obs.shape[0], self.args.rnn_hidden_dim)).to(obs.device)
        # A step counts as valid when any of its features is non-zero.
        mask = torch.any(obs.bool(), dim=-1).float().unsqueeze(-1)
        obs_h = self.flipUp(obs, mask)
        obs_h = self.fc_obs(obs_h)
        obs_h, _ = self.rnn_obs(obs_h, hidden_state)
        obs_h = self.flipUp(obs_h, mask)
        return obs_h.reshape(-1, obs_h.shape[-1])

    def future_s_h(self, s, obs):
        """Encode ``s`` with ``rnn_state`` run over the reversed valid steps.

        The result is repeated ``n_agents`` times along a new dim 1 and then
        flattened to 2-D rows. ``obs`` is accepted but not used.
        """
        hidden_state = torch.zeros((1, s.shape[0], self.args.rnn_hidden_dim)).to(s.device)
        # A step counts as valid when any of its features is non-zero.
        mask = torch.any(s.bool(), dim=-1).float().unsqueeze(-1)
        s_h = self.flipUp(s, mask)
        s_h = self.fc_state(s_h)
        s_h, _ = self.rnn_state(s_h, hidden_state)
        s_h = self.flipUp(s_h, mask)
        return s_h.unsqueeze(1).repeat(1, self.n_agents, 1, 1).reshape(-1, s_h.shape[-1])

    def get_s_MI(self, gru_out_c, s_h, latent_embed, b, a):
        """Return ``(r_MI, s_rew)`` for the state branch.

        ``r_MI`` is KL(``latent_embed`` || inferred Gaussian) summed over the
        latent dimension and reshaped to ``(b, a)``; unlike the sibling
        classes, the inferred parameters are not detached here. ``s_rew`` is
        ``pre_rew_s`` applied to a sample of the inferred Gaussian, averaged
        over the agent axis.
        """
        latent_infer = self.infer_net_s(th.cat([gru_out_c.detach(), s_h], dim=1))
        infer_mu = self.infer_mu_s(latent_infer)
        infer_lnsigma = self.infer_lnsigma_s(latent_infer)
        inferred = D.Normal(infer_mu, th.exp(infer_lnsigma / 2))
        sample = self.reparametrize(infer_mu, infer_lnsigma)
        s_rew = self.pre_rew_s(sample).reshape(-1, self.n_agents, a, 1).mean(dim=1)
        r_MI = kl_divergence(latent_embed, inferred).sum(dim=-1).reshape(b, a)
        return r_MI, s_rew

    def get_obs_MI(self, gru_out_c, obs_h, latent_embed, b, a):
        """Return ``(r_MI, obs_rew)`` for the observation branch (same scheme as ``get_s_MI``)."""
        latent_infer = self.infer_net_obs(th.cat([gru_out_c.detach(), obs_h], dim=1))
        infer_mu = self.infer_mu_obs(latent_infer)
        infer_lnsigma = self.infer_lnsigma_obs(latent_infer)
        inferred = D.Normal(infer_mu, th.exp(infer_lnsigma / 2))
        sample = self.reparametrize(infer_mu, infer_lnsigma)
        obs_rew = self.pre_rew_obs(sample).reshape(-1, self.n_agents, a, 1).mean(dim=1)
        r_MI = kl_divergence(latent_embed, inferred).sum(dim=-1).reshape(b, a)
        return r_MI, obs_rew

    def forward(self, obs, hidden_state, s=None, u=None):
        """Compute Q-values and, when ``s`` is given, the combined KL term.

        Args:
            obs: 3-D tensor; dim 0 is the GRU batch, dim 1 the sequence axis.
            hidden_state: Initial GRU state; a 2-D tensor gets a leading
                layer axis added.
            s: Optional 3-D state tensor. ``None`` skips both KL branches.
            u: 4-D action tensor; required when ``s`` is given.

        Returns:
            ``(q, gru_out, r_MI, s_rew, obs_rew)``; the last three are
            ``None`` when ``s`` is ``None``. When
            ``obs.shape[0] == n_agents`` only
            ``(q.squeeze(), gru_out.permute(1, 0, 2))`` is returned.
        """
        b, a, _ = obs.size()
        if len(hidden_state.shape) == 2:
            # Add the layer axis the GRU expects, e.g. (1, 3, 64).
            hidden_state = hidden_state.unsqueeze(0)
        executing = obs.shape[0] == self.n_agents

        r_MI = None
        s_rew = None
        obs_rew = None

        x = f.relu(self.fc1(obs.view(-1, obs.shape[-1])))
        x = x.reshape(obs.shape[0], obs.shape[1], -1)
        gru_out, _ = self.rnn(x, hidden_state)
        gru_out_c = gru_out.reshape(-1, gru_out.shape[-1])

        indi = self.indi_net(gru_out_c)
        indi_mu = self.indi_mu(indi)  # (b*a, indi_latent_dim)
        indi_lnsigma2 = self.indi_lnsigma2(indi)  # (b*a, indi_latent_dim)
        indi_latent_dis = self.reparametrize(indi_mu, indi_lnsigma2)
        latent_embed = D.Normal(indi_mu, th.exp(indi_lnsigma2 / 2))

        id_indi = self.id_indi_net(gru_out_c)
        id_indi_mu = self.id_indi_mu(id_indi)  # (b*a, indi_latent_dim)
        id_indi_lnsigma2 = self.id_indi_lnsigma2(id_indi)  # (b*a, indi_latent_dim)
        id_indi_latent_dis = self.reparametrize(id_indi_mu, id_indi_lnsigma2)
        obs_latent_embed = D.Normal(id_indi_mu, th.exp(id_indi_lnsigma2 / 2))

        # Identity check: comparing a tensor to None with != is unreliable.
        if s is not None:
            s_f = torch.cat((s, u.reshape(u.shape[0], u.shape[1], -1)), dim=-1)
            actions = u.permute(0, 2, 1, 3).reshape(-1, u.shape[1], u.shape[-1])
            obs_f = torch.cat((obs, actions), dim=-1)
            s_h = self.future_s_h(s_f, obs)
            obs_h = self.future_obs_h(obs_f)
            r_MI_s, s_rew = self.get_s_MI(gru_out_c, s_h, latent_embed, b, a)
            r_MI_obs, obs_rew = self.get_obs_MI(gru_out_c, obs_h, obs_latent_embed, b, a)
            r_MI = self.args.beta1 * r_MI_s + self.args.beta2 * r_MI_obs

        # The same g_params_net embeds both latents.
        g_latent_para = self.g_params_net(indi_latent_dis)
        id_latent_para = self.g_params_net(id_indi_latent_dis)

        q = self.fc2(th.cat([gru_out_c, g_latent_para, id_latent_para], dim=1))
        q = q.reshape(obs.shape[0], obs.shape[1], -1)
        if executing:
            return q.squeeze(), gru_out.permute(1, 0, 2)
        return q, gru_out, r_MI, s_rew, obs_rew

    def reparametrize(self, mu, logvar):
        """Sample ``mu + 0.001 * eps * exp(logvar / 2)`` with ``eps`` ~ N(0, I).

        ``logvar`` is treated as the log-variance. The result has the same
        shape as ``mu``.
        """
        origin_size = mu.size()
        flat_mu = mu.view(-1, mu.size(-1))
        flat_logvar = logvar.view(-1, logvar.size(-1))
        # Standard-normal noise, drawn with the default generator and then
        # moved to mu's device (the Variable wrapper is deprecated and a no-op).
        eps = torch.randn(flat_mu.size(0), flat_mu.size(1)).to(flat_mu.device)
        z = flat_mu + 0.001 * eps * torch.exp(flat_logvar / 2)
        return z.view(origin_size)

    def update(self, inputs, vae_mu, vae_sigma, mask):
        """No-op placeholder; returns ``None``."""
class RNN_future_s_obs(nn.Module):
    """GRU Q-network (no inter-agent communication) with latent-conditioned head.

    The GRU output is encoded into two diagonal-Gaussian latents (``indi_*`` and
    ``id_indi_*``). Samples of both latents generate the weights and bias of a
    per-row linear Q head (a hyper-network), whose output is added to the plain
    ``fc2`` head.

    When ``forward`` also receives the state ``s`` and actions ``u``, two
    auxiliary inference heads are evaluated against encodings of the reversed
    state / observation sequences, producing a KL-based term ``r_MI`` and two
    scalar predictions (``s_rew``, ``obs_rew``).

    Note:
        ``forward`` feeds latent samples of width ``args.indi_latent_dim`` into
        ``g_params_net`` (input width ``args.rnn_hidden_dim``) and feeds two
        ``rnn_hidden_dim``-wide vectors into ``params_dis_net`` (input width
        ``2 * args.indi_latent_dim``), so the two sizes must be equal.

    Args:
        input_shape: Size of the last dimension of ``obs``.
        args: Config object. Attributes read here: ``seed``, ``n_agents``,
            ``n_actions``, ``rnn_hidden_dim``, ``indi_latent_dim``,
            ``state_shape``, ``GRF``, ``env``, ``beta1``, ``beta2``, ``beta3``.
    """

    def __init__(self, input_shape, args):
        super(RNN_future_s_obs, self).__init__()
        self.args = args
        # Re-seed torch before any layer is created. The creation order of the
        # layers below therefore determines their initial weights; keep it.
        setup_seed(args.seed)
        self.input_shape = input_shape
        self.n_agents = args.n_agents
        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.detach_label = bool(self.args.GRF or 'pacmen' in args.env)
        self.rnn = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

        # Encoder for the first latent (compared with the state branch).
        self.indi_net = nn.Sequential(nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.indi_mu = nn.Linear(args.rnn_hidden_dim, args.indi_latent_dim)
        self.indi_lnsigma2 = nn.Linear(args.rnn_hidden_dim, args.indi_latent_dim)

        # Encoder for the second latent (compared with the observation branch).
        self.id_indi_net = nn.Sequential(nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.id_indi_mu = nn.Linear(args.rnn_hidden_dim, args.indi_latent_dim)
        self.id_indi_lnsigma2 = nn.Linear(args.rnn_hidden_dim, args.indi_latent_dim)

        # ``params_net`` is not used by any method of this class; it is kept so
        # that parameter names and seeded initialisation stay unchanged.
        self.params_net = nn.Sequential(nn.Linear(args.indi_latent_dim * 2, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.params_dis_net = nn.Sequential(nn.Linear(args.indi_latent_dim * 2, args.rnn_hidden_dim), nn.ReLU(inplace=True))

        # Hyper-network producing the per-row Q-head weight and bias.
        self.fc2_w_nn = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim * args.n_actions)
        self.fc2_b_nn = nn.Linear(args.rnn_hidden_dim, args.n_actions)

        # Inference head conditioned on the state encoding.
        self.infer_net_s = nn.Sequential(nn.Linear(args.rnn_hidden_dim * 2, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.infer_mu_s = nn.Linear(args.rnn_hidden_dim, args.indi_latent_dim)
        self.infer_lnsigma_s = nn.Linear(args.rnn_hidden_dim, args.indi_latent_dim)
        self.pre_rew_s = nn.Linear(args.indi_latent_dim, 1)

        # Inference head conditioned on the observation encoding.
        self.infer_net_obs = nn.Sequential(nn.Linear(args.rnn_hidden_dim * 2, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.infer_mu_obs = nn.Linear(args.rnn_hidden_dim, args.indi_latent_dim)
        self.infer_lnsigma_obs = nn.Linear(args.rnn_hidden_dim, args.indi_latent_dim)
        self.pre_rew_obs = nn.Linear(args.indi_latent_dim, 1)

        self.g_params_net = nn.Sequential(nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim), nn.ReLU(inplace=True))

        # Embeddings of (state, all agents' actions) and (obs, own action).
        self.fc_state = nn.Sequential(nn.Linear(args.state_shape + args.n_agents * args.n_actions, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.fc_obs = nn.Sequential(nn.Linear(input_shape + args.n_actions, args.rnn_hidden_dim), nn.ReLU(inplace=True))

        self.rnn_state = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )
        self.rnn_obs = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )
        self.fc_obs_state = nn.Sequential(nn.Linear(args.state_shape + args.n_actions, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.obs_wb = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.obs_w_nn = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim ** 2)
        self.obs_b_nn = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.rnn_pastobs = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )
        self.save_mu = []
        self.norm = DynamicNorm(input_shape, only_for_last_dim=True, exclude_one_hot=True, exclude_nan=True)

    def init_hidden(self):
        """Return a zero ``(1, rnn_hidden_dim)`` tensor with ``fc1.weight``'s dtype/device."""
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def flipUp(self, s_in, mask):
        """Reverse each sequence along dim 1, keeping masked-out entries at zero.

        The values read from the flipped tensor at the flipped mask positions
        are written, in order, to the positions where ``mask`` is set; every
        other position of the result is zero.

        Args:
            s_in: Tensor indexed as ``(batch, seq, feat)``.
            mask: Tensor ``(batch, seq, 1)``; non-zero marks a valid step.

        Returns:
            Tensor shaped like ``s_in``.
        """
        mask = mask.repeat(1, 1, s_in.shape[-1])
        s_new = s_in.clone().flip(dims=[1])
        s_flip = torch.zeros_like(s_in)
        mask_flip = mask.flip(dims=[1])
        s_flip[mask.bool()] = s_new[mask_flip.bool()]
        return s_flip

    def past_obs_h(self, obs, b, a):
        """Generate a per-row weight matrix and bias from ``obs``.

        Returns:
            ``(w, bias)`` viewed as ``(b*a, rnn_hidden_dim, rnn_hidden_dim)``
            and ``(b*a, 1, rnn_hidden_dim)``.
        """
        x = self.obs_wb(obs)
        fc2_w = self.obs_w_nn(x)
        fc2_b = self.obs_b_nn(x)
        fc2_w = fc2_w.view(b * a, self.args.rnn_hidden_dim, self.args.rnn_hidden_dim)
        fc2_b = fc2_b.view(b * a, 1, self.args.rnn_hidden_dim)
        return fc2_w, fc2_b

    def _reverse_encode(self, seq, embed, rnn):
        """Embed ``seq`` and run ``rnn`` over it in reversed order along dim 1.

        A step counts as valid when any of its features is non-zero. The GRU
        output is flipped back, so the result is aligned with the input order.
        """
        hidden_state = torch.zeros((1, seq.shape[0], self.args.rnn_hidden_dim)).to(seq.device)
        mask = torch.any(seq.bool(), dim=-1).float().unsqueeze(-1)
        encoded = self.flipUp(seq.clone(), mask)
        encoded = embed(encoded)
        encoded, _ = rnn(encoded, hidden_state)
        return self.flipUp(encoded, mask)

    def future_obs_h(self, obs):
        """Reverse-order GRU encoding of ``obs``, flattened to ``(-1, hidden)``."""
        obs_h = self._reverse_encode(obs, self.fc_obs, self.rnn_obs)
        return obs_h.reshape(-1, obs_h.shape[-1])

    def obs_h(self, obs):
        """Feed-forward (non-recurrent) embedding of ``obs``, flattened to 2-D."""
        obs_h = self.fc_obs(obs)
        return obs_h.reshape(-1, obs_h.shape[-1])

    def s_h(self, s, obs):
        """Feed-forward embedding of ``s``, repeated ``n_agents`` times on a new dim 1.

        ``obs`` is accepted for signature parity with ``future_s_h`` but unused.
        """
        s_h = self.fc_state(s)
        return s_h.unsqueeze(1).repeat(1, self.n_agents, 1, 1).reshape(-1, s_h.shape[-1])

    def future_s_h(self, s, obs):
        """Reverse-order GRU encoding of ``s``, repeated ``n_agents`` times.

        The encoding is repeated on a new dim 1 and flattened to
        ``(-1, hidden)``. ``obs`` is unused.
        """
        s_h = self._reverse_encode(s, self.fc_state, self.rnn_state)
        return s_h.unsqueeze(1).repeat(1, self.n_agents, 1, 1).reshape(-1, s_h.shape[-1])

    def _infer_MI(self, trunk, mu_head, lnsigma_head, rew_head,
                  gru_out_c, target_h, latent_embed, b, a):
        """Shared body of ``get_s_MI`` and ``get_obs_MI``.

        Builds an inferred Gaussian from ``[gru_out_c.detach(), target_h]``,
        returns ``KL(latent_embed || inferred)`` summed over the latent dim and
        reshaped to ``(b, a)``, plus a scalar prediction from a sample of the
        inferred Gaussian averaged over the ``n_agents`` dim.
        """
        latent_infer = trunk(th.cat([gru_out_c.detach(), target_h], dim=1))
        infer_mu = mu_head(latent_infer)
        infer_lnsigma = lnsigma_head(latent_infer)
        # NOTE(review): gradients reach the inference head through the KL term
        # only when ``detach_label`` is True, which reads inverted relative to
        # the flag's name. Behaviour is kept as-is; confirm intent.
        if self.detach_label:
            infer_dist = D.Normal(infer_mu, th.exp(infer_lnsigma / 2))
        else:
            infer_dist = D.Normal(infer_mu.detach(), th.exp(infer_lnsigma.detach() / 2))

        rew_latent_dis = self.reparametrize(infer_mu, infer_lnsigma)
        rew = rew_head(rew_latent_dis).reshape(-1, self.n_agents, a, 1).mean(dim=1)
        r_MI = kl_divergence(latent_embed, infer_dist).sum(dim=-1).reshape(b, a)
        return r_MI, rew

    def get_s_MI(self, gru_out_c, s_h, latent_embed, b, a):
        """KL term and scalar prediction for the state branch; returns ``(r_MI, s_rew)``."""
        return self._infer_MI(self.infer_net_s, self.infer_mu_s, self.infer_lnsigma_s,
                              self.pre_rew_s, gru_out_c, s_h, latent_embed, b, a)

    def get_obs_MI(self, gru_out_c, obs_h, latent_embed, b, a):
        """KL term and scalar prediction for the observation branch; returns ``(r_MI, obs_rew)``."""
        return self._infer_MI(self.infer_net_obs, self.infer_mu_obs, self.infer_lnsigma_obs,
                              self.pre_rew_obs, gru_out_c, obs_h, latent_embed, b, a)

    @staticmethod
    def _kl_to_standard_normal(mu, lnsigma2):
        """KL(N(mu, exp(lnsigma2)) || N(0, I)), summed over the last dim (kept)."""
        return 0.5 * torch.sum(torch.exp(lnsigma2) + torch.pow(mu, 2) - 1. - lnsigma2,
                               dim=-1, keepdim=True)

    def forward(self, obs, hidden_state, s=None, u=None):
        """Compute Q-values and, when ``s``/``u`` are given, the auxiliary terms.

        Args:
            obs: 3-D tensor ``(b, a, input_shape)``. The GRUs are
                ``batch_first``, so dim 1 is the dimension they iterate over.
                (The original comment called the dims "parallel envs, agents,
                obs size" -- TODO confirm against the caller.)
            hidden_state: Initial GRU hidden state; a 2-D tensor is unsqueezed
                to ``(1, b, rnn_hidden_dim)``.
            s: Optional state tensor, concatenated with ``u`` flattened over
                its last two dims.
            u: Actions, required when ``s`` is given. Indexed as a 4-D tensor;
                presumably ``(batch, time, n_agents, n_actions)`` -- verify.

        Returns:
            If ``obs.shape[0] == n_agents`` (treated as execution mode):
            ``(q.squeeze(), gru_out.permute(1, 0, 2))``.
            Otherwise ``(q, gru_out, r_MI, s_rew, obs_rew)``, where the last
            three are ``None`` when ``s`` is ``None``.

        Raises:
            ValueError: If ``s`` is provided without ``u``.
        """
        b, a, _ = obs.size()

        if len(hidden_state.shape) == 2:
            hidden_state = hidden_state.unsqueeze(0)  # ensure (1, b, hidden)
        excute_label = obs.shape[0] == self.n_agents

        r_MI = None
        s_rew = None
        obs_rew = None

        obs_c = obs.view(-1, obs.shape[-1])
        x = f.relu(self.fc1(obs_c))
        x = x.reshape(obs.shape[0], obs.shape[1], -1)
        gru_out, _ = self.rnn(x, hidden_state)
        gru_out_c = gru_out.reshape(-1, gru_out.shape[-1])

        indi = self.indi_net(gru_out_c)
        indi_mu = self.indi_mu(indi)  # b*a, indi_latent_dim
        indi_lnsigma2 = self.indi_lnsigma2(indi)  # b*a, indi_latent_dim
        indi_latent_dis = self.reparametrize(indi_mu, indi_lnsigma2)
        latent_embed = D.Normal(indi_mu, th.exp(indi_lnsigma2 / 2))

        id_indi = self.id_indi_net(gru_out_c)
        id_indi_mu = self.id_indi_mu(id_indi)  # b*a, indi_latent_dim
        id_indi_lnsigma2 = self.id_indi_lnsigma2(id_indi)  # b*a, indi_latent_dim
        id_indi_latent_dis = self.reparametrize(id_indi_mu, id_indi_lnsigma2)
        obs_latent_embed = D.Normal(id_indi_mu, th.exp(id_indi_lnsigma2 / 2))

        # Identity check: comparing a tensor with ``!=`` goes through tensor
        # comparison machinery and is not a reliable None test.
        if s is not None:
            if u is None:
                raise ValueError("forward() requires `u` when `s` is provided")
            s_f = torch.cat((s, u.reshape(u.shape[0], u.shape[1], -1)), dim=-1)
            obs_f = torch.cat((obs, u.permute(0, 2, 1, 3).reshape(-1, u.shape[1], u.shape[-1])), dim=-1)
            s_h = self.future_s_h(s_f, obs)
            obs_h = self.future_obs_h(obs_f)

            KLD_s = self._kl_to_standard_normal(indi_mu, indi_lnsigma2)
            KLD_obs = self._kl_to_standard_normal(id_indi_mu, id_indi_lnsigma2)

            r_MI_s, s_rew = self.get_s_MI(gru_out_c, s_h, latent_embed, b, a)
            r_MI_obs, obs_rew = self.get_obs_MI(gru_out_c, obs_h, obs_latent_embed, b, a)
            r_MI = (self.args.beta1 * (r_MI_s + self.args.beta3 * KLD_s.reshape(b, a))
                    + self.args.beta2 * (r_MI_obs + self.args.beta3 * KLD_obs.reshape(b, a)))
            r_MI = r_MI.clamp(max=20)

        g_latent_para = self.g_params_net(indi_latent_dis)
        id_latent_para = self.g_params_net(id_indi_latent_dis)
        latent_para = self.params_dis_net(th.cat([g_latent_para, id_latent_para], dim=1))

        # Hyper-network head: one (hidden x n_actions) matrix and bias per row.
        fc2_w = self.fc2_w_nn(latent_para)  # b*a, hidden*n_actions
        fc2_b = self.fc2_b_nn(latent_para)  # b*a, n_actions
        fc2_w = fc2_w.view(b * a, self.args.rnn_hidden_dim, self.args.n_actions)
        fc2_b = fc2_b.view(b * a, 1, self.args.n_actions)
        h = gru_out.reshape(b * a, 1, self.args.rnn_hidden_dim)
        q_2 = torch.matmul(h, fc2_w) + fc2_b
        q_1 = self.fc2(h)
        q = q_2 + q_1
        q = q.reshape(obs.shape[0], obs.shape[1], -1)
        if excute_label:
            return q.squeeze(), gru_out.permute(1, 0, 2)
        return q, gru_out, r_MI, s_rew, obs_rew

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + 0.001 * eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        ``logvar`` is treated as the log of the variance. The noise is scaled
        by a fixed 0.001. The result has the same shape as ``mu``.
        """
        origin_size = mu.size()
        mu = mu.view(-1, mu.size(-1))
        logvar = logvar.view(-1, logvar.size(-1))

        # Noise is drawn with the default generator, then moved to mu's device
        # (kept this way so the seeded random stream is unchanged).
        eps = torch.randn(mu.size(0), mu.size(1)).to(mu.device)
        z = mu + 0.001 * eps * torch.exp(logvar / 2)
        return z.view(origin_size)

    def update(self, inputs, vae_mu, vae_sigma, mask):
        """Placeholder hook; intentionally does nothing."""
        pass

class RNN(nn.Module):
    """Recurrent Q-network: linear embedding -> single-layer GRU -> linear Q head.

    All agents share this network; per the original note, ``input_shape`` is
    obs_shape + n_actions + n_agents.
    """

    def __init__(self, input_shape, args):
        super().__init__()
        hidden_dim = args.rnn_hidden_dim
        self.args = args
        self.n_agents = args.n_agents
        self.fc1 = nn.Linear(input_shape, hidden_dim)
        self.rnn = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        self.fc2 = nn.Linear(hidden_dim, args.n_actions)

    def forward(self, obs, hidden_state):
        """Return Q-values and GRU outputs for a 3-D ``obs`` tensor.

        A 2-D ``hidden_state`` gets a leading singleton dimension. When the
        leading size of ``obs`` equals ``n_agents`` the call is treated as
        execution mode and returns ``(q.squeeze(), gru_out.permute(1, 0, 2))``;
        otherwise it returns ``(q, gru_out)``.
        """
        lead, mid = obs.shape[0], obs.shape[1]
        if hidden_state.dim() == 2:
            hidden_state = hidden_state.unsqueeze(0)
        in_execution = lead == self.n_agents

        # Flatten to 2-D for the linear layer, then restore the 3-D layout.
        embedded = f.relu(self.fc1(obs.view(-1, obs.shape[-1])))
        embedded = embedded.reshape(lead, mid, -1)

        gru_out, _ = self.rnn(embedded, hidden_state)
        flat_out = gru_out.reshape(-1, gru_out.shape[-1])
        q = self.fc2(flat_out).reshape(lead, mid, -1)

        if not in_execution:
            return q, gru_out
        return q.squeeze(), gru_out.permute(1, 0, 2)
    #
    # def forward(self, obs, hidden_state):
    #     x = f.relu(self.fc1(obs))
    #     h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
    #     h = self.rnn(x, h_in)
    #     q = self.fc2(h)
    #     return q, h


# Critic of Central-V
class Critic(nn.Module):
    """Three-layer MLP mapping an input vector to a single scalar value."""

    def __init__(self, input_shape, args):
        super().__init__()
        width = args.critic_dim
        self.args = args
        self.fc1 = nn.Linear(input_shape, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, inputs):
        """Apply two ReLU-activated hidden layers followed by a linear output."""
        hidden = inputs
        for layer in (self.fc1, self.fc2):
            hidden = f.relu(layer(hidden))
        return self.fc3(hidden)
class MLP(nn.Module):
    """Single linear layer from a hidden state to per-action values."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.fc = nn.Linear(args.rnn_hidden_dim, args.n_actions)

    def forward(self, hidden_state):
        """Project ``hidden_state`` (last dim ``rnn_hidden_dim``) to ``n_actions``."""
        return self.fc(hidden_state)
